Author: Fernando Felix do Nascimento Junior

Last update: 24/06/2019

Addressable market challange

This notebook is divided in the following topics:

Summary:

- Load libs and modules
- Load raw data sets
- Split actual and addressable customers
- EDA
    - Descriptive statistics analsysis
    - Customer type size analysis
    - Outlier analysis
    - Univariate distribution analysis
- Model building
    - Customer segmentation
    - Decision-tree classifier
    - Customer scorer (pearson similiarity)
- Generate Deliverables
    - Deliverable 1
    - Deliverable 2
    - Sanity Check
- Improvements (TODO)
- Cluster based classifier and scorer (deprecated)

It runs on top of IBM Watson Studio with the following hardware config:

- 2 Executors: 1 vCPU and 4 GB RAM, Driver: 1 vCPU and 4 GB RAM

Link (Private):

Load libs

In [ ]:
# requirements.txt
!pip install --upgrade wget
!pip install --upgrade plotly
Waiting for a Spark session to start...
Spark Initialization Done! ApplicationId = app-20190624225824-0001
KERNEL_ID = 189f1449-94e5-4fe4-b9bf-d9bb87db184a
Collecting wget
  Downloading https://files.pythonhosted.org/packages/47/6a/62e288da7bcda82b935ff0c6cfe542970f04e29c756b0e147251b2fb251f/wget-3.2.zip
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... done
  Stored in directory: /home/spark/shared/.cache/pip/wheels/40/15/30/7d8f7cea2902b4db79e3fea550d7d7b85ecb27ef992b618f3f
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2
Collecting plotly
  Downloading https://files.pythonhosted.org/packages/ff/75/3982bac5076d0ce6d23103c03840fcaec90c533409f9d82c19f54512a38a/plotly-3.10.0-py2.py3-none-any.whl (41.5MB)
    100% |################################| 41.5MB 505kB/s eta 0:00:01
Collecting retrying>=1.3.3 (from plotly)
  Downloading https://files.pythonhosted.org/packages/44/ef/beae4b4ef80902f22e3af073397f079c96969c69b2c7d52a57ea9ae61c9d/retrying-1.3.3.tar.gz
Collecting nbformat>=4.2 (from plotly)
  Downloading https://files.pythonhosted.org/packages/da/27/9a654d2b6cc1eaa517d1c5a4405166c7f6d72f04f6e7eea41855fe808a46/nbformat-4.4.0-py2.py3-none-any.whl (155kB)
    100% |################################| 163kB 4.6MB/s eta 0:00:01
Collecting six (from plotly)
  Downloading https://files.pythonhosted.org/packages/73/fb/00a976f728d0d1fecfe898238ce23f502a721c0ac0ecfedb80e0d88c64e9/six-1.12.0-py2.py3-none-any.whl
Collecting pytz (from plotly)
  Downloading https://files.pythonhosted.org/packages/3d/73/fe30c2daaaa0713420d0382b16fbb761409f532c56bdcc514bf7b6262bb6/pytz-2019.1-py2.py3-none-any.whl (510kB)
    100% |################################| 512kB 3.9MB/s eta 0:00:01
Collecting requests (from plotly)
  Downloading https://files.pythonhosted.org/packages/51/bd/23c926cd341ea6b7dd0b2a00aba99ae0f828be89d72b2190f27c11d4b7fb/requests-2.22.0-py2.py3-none-any.whl (57kB)
    100% |################################| 61kB 2.9MB/s eta 0:00:01
Collecting decorator>=4.0.6 (from plotly)
  Downloading https://files.pythonhosted.org/packages/5f/88/0075e461560a1e750a0dcbf77f1d9de775028c37a19a346a6c565a257399/decorator-4.4.0-py2.py3-none-any.whl
Collecting traitlets>=4.1 (from nbformat>=4.2->plotly)
  Downloading https://files.pythonhosted.org/packages/93/d6/abcb22de61d78e2fc3959c964628a5771e47e7cc60d53e9342e21ed6cc9a/traitlets-4.3.2-py2.py3-none-any.whl (74kB)
    100% |################################| 81kB 3.5MB/s eta 0:00:01
Collecting jsonschema!=2.5.0,>=2.4 (from nbformat>=4.2->plotly)
  Downloading https://files.pythonhosted.org/packages/aa/69/df679dfbdd051568b53c38ec8152a3ab6bc533434fc7ed11ab034bf5e82f/jsonschema-3.0.1-py2.py3-none-any.whl (54kB)
    100% |################################| 61kB 3.4MB/s eta 0:00:01
Collecting ipython-genutils (from nbformat>=4.2->plotly)
  Downloading https://files.pythonhosted.org/packages/fa/bc/9bd3b5c2b4774d5f33b2d544f1460be9df7df2fe42f352135381c347c69a/ipython_genutils-0.2.0-py2.py3-none-any.whl
Collecting jupyter-core (from nbformat>=4.2->plotly)
  Downloading https://files.pythonhosted.org/packages/e6/25/6ffb0f6e57fa6ef5d2f814377133b361b42a6dd39105f4885a4f1666c2c3/jupyter_core-4.5.0-py2.py3-none-any.whl (78kB)
    100% |################################| 81kB 3.5MB/s eta 0:00:01
Collecting urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 (from requests->plotly)
  Downloading https://files.pythonhosted.org/packages/e6/60/247f23a7121ae632d62811ba7f273d0e58972d75e58a94d329d51550a47d/urllib3-1.25.3-py2.py3-none-any.whl (150kB)
    100% |################################| 153kB 4.3MB/s eta 0:00:01
Collecting chardet<3.1.0,>=3.0.2 (from requests->plotly)
  Downloading https://files.pythonhosted.org/packages/bc/a9/01ffebfb562e4274b6487b4bb1ddec7ca55ec7510b22e4c51f14098443b8/chardet-3.0.4-py2.py3-none-any.whl (133kB)
    100% |################################| 143kB 4.3MB/s eta 0:00:01
Collecting idna<2.9,>=2.5 (from requests->plotly)
  Downloading https://files.pythonhosted.org/packages/14/2c/cd551d81dbe15200be1cf41cd03869a46fe7226e7450af7a6545bfc474c9/idna-2.8-py2.py3-none-any.whl (58kB)
    100% |################################| 61kB 3.1MB/s eta 0:00:01
Collecting certifi>=2017.4.17 (from requests->plotly)
  Downloading https://files.pythonhosted.org/packages/69/1b/b853c7a9d4f6a6d00749e94eb6f3a041e342a885b87340b79c1ef73e3a78/certifi-2019.6.16-py2.py3-none-any.whl (157kB)
    100% |################################| 163kB 4.9MB/s eta 0:00:01
Collecting attrs>=17.4.0 (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2->plotly)
  Downloading https://files.pythonhosted.org/packages/23/96/d828354fa2dbdf216eaa7b7de0db692f12c234f7ef888cc14980ef40d1d2/attrs-19.1.0-py2.py3-none-any.whl
Collecting pyrsistent>=0.14.0 (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2->plotly)
  Downloading https://files.pythonhosted.org/packages/68/0b/f514e76b4e074386b60cfc6c8c2d75ca615b81e415417ccf3fac80ae0bf6/pyrsistent-0.15.2.tar.gz (106kB)
    100% |################################| 112kB 5.0MB/s eta 0:00:01
Collecting setuptools (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2->plotly)
  Downloading https://files.pythonhosted.org/packages/ec/51/f45cea425fd5cb0b0380f5b0f048ebc1da5b417e48d304838c02d6288a1e/setuptools-41.0.1-py2.py3-none-any.whl (575kB)
    100% |################################| 583kB 4.5MB/s eta 0:00:01
Building wheels for collected packages: retrying, pyrsistent
  Building wheel for retrying (setup.py) ... done
  Stored in directory: /home/spark/shared/.cache/pip/wheels/d7/a9/33/acc7b709e2a35caa7d4cae442f6fe6fbf2c43f80823d46460c
  Building wheel for pyrsistent (setup.py) ... done
  Stored in directory: /home/spark/shared/.cache/pip/wheels/6b/b9/15/c8c6a1e095a370e8c3273e65a5c982e5cf355dde16d77502f5
Successfully built retrying pyrsistent
tensorflow 1.13.1 requires tensorboard<1.14.0,>=1.13.0, which is not installed.
spyder 3.3.3 requires pyqt5<=5.12; python_version >= "3", which is not installed.
ibm-cos-sdk-core 2.4.3 has requirement urllib3<1.25,>=1.20, but you'll have urllib3 1.25.3 which is incompatible.
botocore 1.12.82 has requirement urllib3<1.25,>=1.20, but you'll have urllib3 1.25.3 which is incompatible.
Installing collected packages: six, retrying, decorator, ipython-genutils, traitlets, attrs, pyrsistent, setuptools, jsonschema, jupyter-core, nbformat, pytz, urllib3, chardet, idna, certifi, requests, plotly
Successfully installed attrs-19.1.0 certifi-2019.6.16 chardet-3.0.4 decorator-4.4.0 idna-2.8 ipython-genutils-0.2.0 jsonschema-3.0.1 jupyter-core-4.5.0 nbformat-4.4.0 plotly-3.10.0 pyrsistent-0.15.2 pytz-2019.1 requests-2.22.0 retrying-1.3.3 setuptools-41.0.1 six-1.12.0 traitlets-4.3.2 urllib3-1.25.3
In [ ]:
# config.py

import ibmos2spark
# @hidden_cell
credentials = {
    'endpoint': 'https://s3-api.us-geo.objectstorage.service.networklayer.com',
    'service_id': 'iam-ServiceId-3c5cb5d7-3787-4432-bd0d-8815e65be261',
    'iam_service_endpoint': 'https://iam.bluemix.net/oidc/token',
    'api_key': 'bW3TRvTFioeaJ2FBZYJ3djHqsWlilR2G1eppRuyPRd7z'
}

configuration_name = 'os_25e3d57cda46474ca2767163f7a810b0_configs'
cos = ibmos2spark.CloudObjectStorage(sc, credentials, configuration_name, 'bluemix_cos')


import wget

import numpy as np
import pandas as pd

from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark import SparkContext
from pyspark import HiveContext

import pyspark.sql.types as T
import pyspark.sql.functions as F

sc = SparkContext.getOrCreate()
sqlContext = HiveContext(sc)
spark = sqlContext.sparkSession

SEED = 27
In [ ]:
# utils.py


def load_csv_as_dataframe(filename):
    df = spark.read\
        .format('org.apache.spark.sql.execution.datasources.csv.CSVFileFormat')\
        .option('header', 'true')\
        .option('inferSchema', 'true')\
        .load(filename)

    df = checkpoint(df, filename + '.parquet')

    return df



def load_dataframe_from_url(link_to_data):
    filename = wget.download(link_to_data)

    df = spark.read\
        .format('org.apache.spark.sql.execution.datasources.csv.CSVFileFormat')\
        .option('header', 'true')\
        .option('inferSchema', 'true')\
        .load(filename)

    df = checkpoint(df, filename + '.parquet')

    return df


def load_dataframe(filename):
    return spark.read.option("mergeSchema", "true").parquet(filename)


def checkpoint(df, filename):
    '''
    Saves a dataframe from memory to a parquet file then read the file
    '''
    df.write.mode('overwrite').parquet(filename)
    return spark.read.option("mergeSchema", "true").parquet(filename)


def inspect_dataframe(df, n=20):
    '''
    Shows data dimension, schema and samples.
    '''
    print('Dimension:', df.count(), len(df.columns))
    print('Schema:')
    df.printSchema()
    print('Sample n={n}:'.format(n=n))
    df.show(n)
    return df


def apply_agg_fn(fn, data, columns):
    '''
    Applies an aggregate function (eg, mean, variance, count, etc.) to a list of columns of a spark dataframe
    '''
    dfagg = data.select(columns).agg(*([fn(c) for c in columns])).toDF(*columns)

    return dfagg.select(*[F.col(c).cast(T.DecimalType(18, 2)).alias(c) for c in dfagg.columns]) # avoid sci notation


def transpose_as_pd(data, index=None):
    '''
    Performs transposition of a spark datraframe and then transforms into a pandas dataframe
    '''
    return data.toPandas().set_index(index).transpose()


def describe(data, columns):
    '''
    Summarizes a spark dataframe applying the following aggregate functions: count, mean, stddev, minmax, var.
    '''
    dfagg = transpose_as_pd(data.select(columns).describe(), index='summary')
    dfvar = apply_agg_fn(F.variance, data, columns).toPandas().rename(index={0: 'var'}).transpose()

    df = pd.concat([dfagg, dfvar], axis=1)
    display(df)
    return data
In [ ]:
# viz.py

# https://plot.ly/python/apache-spark/
# https://plot.ly/python/offline/
import plotly.offline as pyo
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.graph_objs as go
from plotly import tools
import plotly.figure_factory as ff

init_notebook_mode(connected=True)


def grouped_bar_plot(df_list, title_list, group_col):
    '''
    https://plot.ly/python/bar-charts/#grouped-bar-chart
    '''
    bar_data = []

    for df, title in zip(df_list, title_list):
        df_by = df.groupBy(group_col).agg(F.count('*').alias('count'))

        #print('Number of items for each group in {title} dataset:'.format(title=title))
        #df_by.show()

        df_by = df_by.collect()

        trace = go.Bar(x = [i[group_col] for i in df_by], y = [i['count'] for i in df_by], name=title)

        bar_data.append(trace)

    layout = go.Layout(barmode='group')

    fig = go.Figure(data=bar_data, layout=layout)
    iplot(fig, filename='grouped-bar')


def boxplots(data, columns):
    '''
    https://plot.ly/python/box-plots/
    https://dataplatform.cloud.ibm.com/exchange/public/entry/view/d80de77f784fed7915c14353512ef14d
    '''
    data_pd = data.select(columns).toPandas()

    traces = []

    for colname in columns:
        traces.append(go.Box(y = data_pd[colname], name = colname))
    
    return iplot(traces)


def dist_plots(data, columns, show_hist=True):
    '''
    - https://plot.ly/python/distplot/
    - https://en.wikipedia.org/wiki/Kernel_density_estimation
    '''
    hist_data = []
    colors = ['#333F44', '#37AA9C', '#94F3E4', '#94F3E4', '#94F3E4', '#94F3E4', '#94F3E4', '#94F3E4', '#94F3E4', '#94F3E4']

    for colname in columns:
        df = data.select(colname).toPandas()[colname]
        hist_data.append(df)

    fig = ff.create_distplot(hist_data, columns, show_hist=show_hist, show_rug=False)
    fig['layout'].update(title='KDE curve plots')

    iplot(fig, filename='Kernel density estimation curve plots')


def line_plot(x, y, title, x_title, y_title, x_range=None, y_range=None):
    '''
    https://plot.ly/python/line-charts/#simple-line-plot
    '''
    xaxis = dict(title = x_title, range=x_range)
    yaxis = dict(title = y_title, range=y_range)
    layout = dict(title = title, xaxis = xaxis, yaxis = yaxis)
    data = [go.Scatter(x = x, y = y)]
    fig = dict(data=data, layout=layout)
    iplot(fig)
In [ ]:
# clustering.py

from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.clustering import KMeans


def tsse(df, feature_cols):
    '''
    Total sum of squared error for multivariate data

    https://stackoverflow.com/a/21385702/4159153
    '''
    df = df.agg(*([(F.variance(F.col(colname)) * (F.count('*') - 1)).alias(colname) for colname in feature_cols]))
    df = df.withColumn('tsse', sum(df[colname] for colname in feature_cols))
    return df.select('tsse')


def cluster_data(dataset, feature_cols, k, max_iter=100, seed=None):
    '''
    https://spark.apache.org/docs/2.2.0/ml-clustering.html#k-means
    https://spark.apache.org/docs/latest/ml-features.html#vectorassembler

    TODO use spark pipeline
    '''
    dataset = dataset.drop('features')
    dataset = dataset.drop('prediction')
    assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')

    dataset = assembler.transform(dataset)

    # Trains a k-means model.
    kmeans = KMeans().setK(k).setMaxIter(max_iter).setSeed(seed)
    model = kmeans.fit(dataset)

    # Make predictions
    dataset = model.transform(dataset)

    # Centroids
    centers = model.clusterCenters()

    # Evaluate clustering by computing Within Set Sum of Squared Errors.
    wssse_score = model.computeCost(dataset)

    dataset = dataset.drop('features')
    return {'model': model, 'predictions': dataset, 'centers': centers, 'wssse_score': wssse_score}


def elbow_curve(df, feature_cols, max_k=90, seed=None):
    '''
    Generates a 2-d tuple containing:
    - a set (list) of number of cluster k >= 1 and k <= max_k. 
    - a list containing a within set sum of squared errors (wssse) for each k >= 2 and k <= max_k
        - if k == 1, it returns the total sum of squared errors of the dataframe[feature_cols]
    '''
    k_list = []
    k_scores = []

    for k in range(1, max_k):
        if k == 1:
            score = tsse(df, feature_cols).collect()[0]['tsse']
            print("Total Sum of Squared Errors k=1; no clustered data: {score}".format(score=score))
        else:
            cluster_results = cluster_data(df, feature_cols, k, seed=seed)
            score = cluster_results['wssse_score']
            print("Within Set Sum of Squared Errors k={k}: {score}".format(k=k, score=score))

        k_list.append(k)
        k_scores.append(score)

    return k_list, k_scores


def elbow_curve_plot(k_list, k_scores, variability_reduction_rate=False):
    '''
    Generates an elbow curve
    '''
    if not variability_reduction_rate:
        return line_plot(k_list, k_scores, 'Elbow curve - WSSSE', 'Number of clusters k', 'Within-clusters Set Sum of Squared Errors - WSSSE')

    k_scores = (pd.Series(k_scores)[0] - pd.Series(k_scores)) / pd.Series(k_scores)[0]
    title = 'Elbow curve - variability reduction rate'
    x_title = 'Number of clusters k'
    y_title = 'WSSSE reduction rate over TSSE: (TSSE - WSSSE / TSSE)'
    return line_plot(k_list, k_scores, title, x_title, y_title)
In [72]:
# classifier.py

import inspect
from functools import reduce

from pyspark.sql import DataFrame
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler, VectorIndexer
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.classification import DecisionTreeClassifier


def generate_resample_params_list(data, label_column):
    '''
    Generates the parameters for each label class to use in resample function
    '''
    params_list = data.groupBy(label_column).agg(F.count('*').alias('freq'))

    max_freq = int(params_list.agg(F.mean('freq').alias('max_freq')).collect()[0]['max_freq'])
    print('Max items for each label class:')
    print(max_freq)

    params_list = params_list.withColumn('max_freq', F.lit(max_freq))
    params_list = params_list.withColumn('fraction', F.round((max_freq / F.col('freq')), 2))
    params_list = params_list.withColumn('fraction', F.round(F.col('fraction') * F.col('freq')) / F.col('freq'))
    params_list = params_list.withColumn('with_replacement', F.col('freq') < max_freq)
    params_list = list(i.asDict() for i in params_list.collect())
    return params_list


def resample(data, label_column, seed=None):
    '''
    Performs data resampling according to the frequency of each class of the label to deal with class imbalance.
    It uses both upsampling and downsampling methods:
        - If a frequency of a label class is smaller than the average frequency use upsampling (with substitution)
        - otherwise use downsampling (without substitution).

    The main goal of this method is to avoid noise increasement by copying / adding data randomly via upsampling (pseudo-random).

    Ref:
        https://stackoverflow.com/a/53990745/4159153
    '''
    class_params_list = generate_resample_params_list(data, label_column)

    assert(label_column in data.columns)

    sample_dataframes = []

    for i, class_params in enumerate(class_params_list):
        sample_class = data.filter(F.col(label_column) == class_params[label_column])

        if class_params[label_column] is None or class_params['fraction'] in [None, 0.0]:
            continue

        sample_class = sample_class.sample(class_params['with_replacement'], float(class_params['fraction']), seed=seed)
        sample_dataframes.append(sample_class)

    return reduce(DataFrame.unionAll, sample_dataframes)


def assemble_vector(dataframe, input_cols, output_col):
    '''
    Combines a given list of columns into a single vector column in a a pyspark dataframe

    https://spark.apache.org/docs/latest/ml-features.html#vectorassembler
    '''
    assembler = VectorAssembler(inputCols=input_cols, outputCol=output_col)
    return assembler.transform(dataframe.na.drop())


def build_param_grid_maps(classifier, param_grid_dict):
    '''
    Converts a grid search parameters from python dict format to ParamGridBuilder.build() result

    https://spark.apache.org/docs/2.2.0/ml-tuning.html
    https://spark.apache.org/docs/2.2.0/api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder
    https://spark.apache.org/docs/2.2.0/api/scala/index.html#org.apache.spark.ml.param.ParamMap
    '''
    estimator_param_maps = ParamGridBuilder()

    # https://stackoverflow.com/a/9058322/4159153s
    classifier_attributes = inspect.getmembers(classifier, lambda a:not(inspect.isroutine(a)))
    classifier_attributes = [a for a in classifier_attributes if not(a[0].startswith('__') and a[0].endswith('__'))]
    classifier_attributes = dict(classifier_attributes)

    for param_key, param_value in param_grid_dict.items():
        estimator_param_maps.addGrid(classifier_attributes[param_key], param_value)

    return estimator_param_maps.build()


def train_multiclass_classifier(classifier_cls, training, testing, feature_col_list, label_col, param_grid, metric='accuracy'):
    '''
    Trains a tunned multiclass classifier using cross-validation and grid search params
    '''
    features_col = 'features'
    prediction_col= 'prediction'

    assert(features_col not in training.columns)
    assert(features_col not in testing.columns)
    assert(features_col not in feature_col_list)
    assert(features_col != label_col)
    assert(isinstance(param_grid, dict))

    # include features into a single column vector for training and testing purpose
    training = assemble_vector(training, feature_col_list, features_col)
    testing = assemble_vector(testing, feature_col_list, features_col)

    # training
    classifier = classifier_cls(featuresCol = features_col, labelCol = label_col)
    param_grid_maps = build_param_grid_maps(classifier, param_grid)
    evaluator = MulticlassClassificationEvaluator(labelCol=label_col, predictionCol=prediction_col, metricName=metric) 
    validator = CrossValidator(estimator=classifier, estimatorParamMaps=param_grid_maps, evaluator=evaluator, numFolds=5)
    model = validator.fit(training)

    # evaluation
    testing_predictions = model.transform(testing)
    testing_score = evaluator.evaluate(testing_predictions)

    training_predictions = model.transform(training)
    training_score = evaluator.evaluate(training_predictions)

    best_param_map = model.bestModel.extractParamMap()

    return model, testing_score, training_score, testing_predictions, training_predictions, best_param_map

Load raw data sets

Let's load the raw datasets from customer CRM and Neoway firmographic.

In [ ]:
# customer_crm = load_dataframe_from_url('https://gist.githubusercontent.com/fernandojunior/e30bbab8298fb8a57c78f52079503fd8/raw/0f8995acab136324237b17a670fa083db674462d/customer_CRM_2019-05-17.csv')
# neoway_db = load_dataframe_from_url('https://gist.githubusercontent.com/fernandojunior/e30bbab8298fb8a57c78f52079503fd8/raw/0f8995acab136324237b17a670fa083db674462d/Neoway_database_2019-05-17.csv')
In [88]:
customer_crm = load_csv_as_dataframe('customer_CRM_2019-05-17.csv')
neoway_db = load_csv_as_dataframe('Neoway_database_2019-05-17.csv')

Inspect customer_crm dataset (dimension, schema, sample):

In [8]:
customer_crm = inspect_dataframe(customer_crm, n=3)
Dimension: 2000 1
Schema:
root
 |-- id: integer (nullable = true)

Sample n=3:
+----+
|  id|
+----+
|9796|
|  87|
|6005|
+----+
only showing top 3 rows

Inspect customer_crm dataset (dimension, schema, sample):

In [9]:
neoway_db = inspect_dataframe(neoway_db, n=3)
Dimension: 40001 12
Schema:
root
 |-- id: integer (nullable = true)
 |-- type: string (nullable = true)
 |-- feat_0: double (nullable = true)
 |-- feat_1: double (nullable = true)
 |-- feat_2: double (nullable = true)
 |-- feat_3: double (nullable = true)
 |-- feat_4: double (nullable = true)
 |-- feat_5: double (nullable = true)
 |-- feat_6: double (nullable = true)
 |-- feat_7: double (nullable = true)
 |-- feat_8: double (nullable = true)
 |-- feat_9: double (nullable = true)

Sample n=3:
+-----+-----------+-------------------+-------------------+-------------------+--------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+
|   id|       type|             feat_0|             feat_1|             feat_2|              feat_3|             feat_4|             feat_5|             feat_6|            feat_7|            feat_8|             feat_9|
+-----+-----------+-------------------+-------------------+-------------------+--------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+
|14868| RESTAURANT|-1.0422452219339224|-0.6279247488601141|0.37677205146275683| -1.8517203647128966|-0.9130655710905651|-2.6988866204772095| 2.1063215061794778|0.7014081107181209| 1.184736729893727|-1.2206730744414633|
| 5569|SUPERMARKET|  0.972655692135186|-0.8461329999971695| 1.3205124597799167|-0.28581182386228776| 1.0929431382011188| 0.1200987327500529|-0.7263606945323482| 1.140403609574627|1.1144257050775448| 1.4977868247317494|
|16598|SUPERMARKET| 1.6035354091944014|-0.7056561871488535|  2.038543783302852|  0.4885742931966455| 2.3825082674711116|  1.912406275161759|-1.8865528053198803|0.9483690101441813|1.3589304325235478|  2.168985683732492|
+-----+-----------+-------------------+-------------------+-------------------+--------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+
only showing top 3 rows

Notes:

  • The data size, schema and sample correspond the expected format from CSV files.

Split actual and addressable customers

In this section, we will carry out the following activities:

  • Merge raw datasets into a single dataset (all_customers)
  • Remove duplicate data items (distinct)
  • Identify and count actual customers (addressable == False) and addressable customers (addressable == True)
  • Split the data between actual and addressable customers
In [21]:
# merge data and create new column to identify actual or addresable customer
customer_crm = customer_crm.withColumn('addressable', F.lit(False))
all_customers = neoway_db.join(customer_crm, ['id'], 'left_outer')
all_customers = all_customers.withColumn('addressable', F.coalesce(F.col('addressable'), F.lit(True)))

all_customers = all_customers.distinct() # removendo dados duplicados

all_customers.groupBy('addressable').agg(F.count('*')).show()
+-----------+--------+
|addressable|count(1)|
+-----------+--------+
|       true|   18001|
|      false|    2000|
+-----------+--------+

Notes:

  • Actual customer size: 18001
  • Addressable customer size: 2000

Now, let's split the all_customers dataset into actual_customers and addressable_customers datasets.

In [22]:
# split all customers into actual_customers and addressable_customers
actual_customers = all_customers.filter(F.col('addressable') == False)
addressable_customers = all_customers.filter(F.col('addressable') == True)
In [23]:
# checkpoints to refresh spark DAG
all_customers = checkpoint(all_customers, 'all_customers.parquet')
actual_customers = checkpoint(actual_customers, 'actual_customers.parquet')
addressable_customers = checkpoint(addressable_customers, 'addressable_customers.parquet')

EDA

Before performing EDA, let's identify the features and customer types:

In [24]:
feature_cols = ['feat_0', 'feat_1', 'feat_2', 'feat_3',  'feat_4',  'feat_5', 'feat_6', 'feat_7', 'feat_8', 'feat_9']
customer_type_list = [i.type for i in all_customers.select('type').distinct().filter(F.col('type').isNotNull()).collect()]

print('Feature columns: ', feature_cols)
print('Customer types:', customer_type_list)
Feature columns:  ['feat_0', 'feat_1', 'feat_2', 'feat_3', 'feat_4', 'feat_5', 'feat_6', 'feat_7', 'feat_8', 'feat_9']
Customer types: ['C-STORE', 'CHURCH', 'SUPERMARKET', 'RESTAURANT', 'BAR', 'SCHOOL', 'HAIR SALOON']

Descriptive statistics analsysis

Let's summarize the descriptive statistics for each dataset: actual_customers, addressable_customers.

In [26]:
print('Descriptive statistics for actual customers dataset:')
actual_customers = describe(actual_customers, feature_cols)

print('Descriptive statistics for addressable customers dataset:')
addressable_customers = describe(addressable_customers, feature_cols)
Descriptive statistics for actual customers dataset:
count mean stddev min max var
feat_0 2000 0.9060348559759658 1.274418281800353 -4.081963458636036 5.35443648832571 1.62
feat_1 2000 0.18234375458121313 0.9641054734157618 -2.8645646726925493 3.484946209967047 0.93
feat_2 2000 -1.5922498227813566 1.3139478715268662 -5.41118130614965 5.83429522864201 1.73
feat_3 2000 -0.31912429870844133 1.1595602747469658 -6.618385700773195 4.559022921561447 1.34
feat_4 2000 -1.5175392395959504 1.1933796035510216 -4.9985515783329575 5.874028198476527 1.42
feat_5 2000 0.927768461871796 1.0167261401392471 -3.301503836364102 5.18793127049459 1.03
feat_6 2000 0.25971634764176216 1.0226380416595076 -7.15572661678598 4.37115779512938 1.05
feat_7 2000 -0.8660342752627497 1.178667503750749 -4.273696626476979 4.307985822909862 1.39
feat_8 2000 -0.890540297602544 1.3086982039581163 -5.18510836935514 3.6403758231765346 1.71
feat_9 2000 -0.9357658746494996 1.0636771839147994 -4.904202284024512 3.763376407479275 1.13
Descriptive statistics for addressable customers dataset:
count mean stddev min max var
feat_0 18000 -0.09556021002567075 1.5044133597372051 -5.09084922894823 5.375736317068322 2.26
feat_1 18000 -0.16908276729050353 0.9787610484721231 -3.7920614303893796 3.608480659267811 0.96
feat_2 18000 0.5450102366570703 1.601963838086544 -6.20637652896264 8.584294720612778 2.57
feat_3 18000 -0.8741001166981852 2.120011602969179 -10.697925918813686 8.500174232615521 4.49
feat_4 18000 -0.1855691097791432 1.9047149121050624 -6.459551244702993 8.40725410062145 3.63
feat_5 18000 -0.09890049186292867 1.4993743038511864 -5.353589176056498 6.552535880615502 2.25
feat_6 18000 -0.01572345673429759 1.6664719915716169 -10.400622715851654 7.513861407580052 2.78
feat_7 18000 0.6485546177040741 1.610698851369331 -5.521470706495204 6.665227991937838 2.59
feat_8 18000 0.663232577135456 1.3321626974017562 -6.362184079215283 5.449992626055867 1.77
feat_9 18000 -0.4623455335077495 1.643335434717868 -6.853951143455957 7.034025634528622 2.70

Notes:

  • The features have the same data size for each dataset. This indicates that features don't have missing values (null, nan)
  • The basic statistics of the addressable_customers and actual_customers datasets are similar, except for variance.
  • The mean and variance of actual_customers is close to 0 and 1, respectively. This indicates that the data was standardized.
  • The addressable_customers dataset appeared to have been standardized using z-score parameters (mean, variance) from actual_customers dataset.

Customer type size analysis

Let's analyse the number of customers by type for each dataset: actual customers, addressable customers, all customers.

In [28]:
grouped_bar_plot([actual_customers, addressable_customers, all_customers], ['actual customers', 'addressable customers', 'all customers'], 'type')

Notes:

  • The actual_customers dataset doesn't have all customer types.      - actual_customers types: church, supermarket, restaurant, school      - addressable_customers types: church, supermarket, restaurant, school, c-store, bar, hair saloon
  • The type sizes aren't balanced for both datasets
  • Therefore, in our modeling, we can't perform segmentation based on type field. We will use k-means clustering to find the most fitting segments derived from customer characteristics.

Outlier analysis

Let's perform a simple outlier analysis by applying boxplot for each customer type on merged dataset (actual dataset + addressable dataset).

In [29]:
for customer_type in customer_type_list:
    print('Boxplot for ', customer_type)
    boxplots(all_customers.filter(F.col('type') == customer_type), feature_cols)
Boxplot for  C-STORE
Boxplot for  CHURCH
Boxplot for  SUPERMARKET
Boxplot for  RESTAURANT
Boxplot for  BAR
Boxplot for  SCHOOL
Boxplot for  HAIR SALOON
  • Embora o dataset possua outliers, a distribuicao dos outliers nos boxplots nao parecem ter comportamento estranho, como presenca de dados oriundos de erros ou outliers extremos que possam interferir na estatistica dos dados.
  • Os outliers parecem estar ajudando a discriminar o comportamento de cada tipo de customer. Por exemplo, alguns tipos tem mais outliers em determinadas features
  • Portanto, por enquanto, nenhum tratamento vai ser realizado para remover outliers

Univariate distribution analysis

In [31]:
dist_plots(all_customers.na.drop(), feature_cols, show_hist=False)

Notes:

  • By analyzing the kurtosis and skewness of the curves visually, we can say that in general the features are normally distributed.
    • Some features more than the others.
  • We could apply the Shapiro-Wilk test in each of the features to validate.
  • Therefore, if necessary, we could apply some parametric statistical tests that require the data to be distributed normally [link] (https://www.originlab.com/doc/Origin-Help/Normality-Test)

Model building

Customer segmentation

The main goal of k-means is to decrease the data variability by grouping similar items. To find the optimal number of k, we will use a custom elbow method, which aims to analyze the proportion of how much the WSSE decreases over the TSSE for each k.

In [37]:
k_list, k_scores = elbow_curve(actual_customers, feature_cols, max_k=70, seed=SEED)
elbow_curve_plot(k_list, k_scores, variability_reduction_rate=True)
Total Sum of Squared Errors k=1; no clustered data: 26710.06443133542
Within Set Sum of Squared Errors k=2: 18863.045888392717
Within Set Sum of Squared Errors k=3: 15516.081830950678
Within Set Sum of Squared Errors k=4: 13656.774540794198
Within Set Sum of Squared Errors k=5: 12514.740017401393
Within Set Sum of Squared Errors k=6: 11379.6171044741
Within Set Sum of Squared Errors k=7: 10493.095866647835
Within Set Sum of Squared Errors k=8: 9873.501425902088
Within Set Sum of Squared Errors k=9: 9248.177900066532
Within Set Sum of Squared Errors k=10: 8571.776513051265
Within Set Sum of Squared Errors k=11: 8109.474755593308
Within Set Sum of Squared Errors k=12: 7720.847704933437
Within Set Sum of Squared Errors k=13: 7221.55627073083
Within Set Sum of Squared Errors k=14: 7121.483657117699
Within Set Sum of Squared Errors k=15: 6728.692251238528
Within Set Sum of Squared Errors k=16: 6526.99202546487
Within Set Sum of Squared Errors k=17: 6290.75706109168
Within Set Sum of Squared Errors k=18: 6088.406408578602
Within Set Sum of Squared Errors k=19: 5991.4030519904145
Within Set Sum of Squared Errors k=20: 5714.836294153299
Within Set Sum of Squared Errors k=21: 5573.347893668797
Within Set Sum of Squared Errors k=22: 5439.619109700317
Within Set Sum of Squared Errors k=23: 5276.83473013073
Within Set Sum of Squared Errors k=24: 5178.007584530702
Within Set Sum of Squared Errors k=25: 4973.743713026908
Within Set Sum of Squared Errors k=26: 4884.047605084491
Within Set Sum of Squared Errors k=27: 4778.417411128934
Within Set Sum of Squared Errors k=28: 4687.715477979777
Within Set Sum of Squared Errors k=29: 4627.0631554083375
Within Set Sum of Squared Errors k=30: 4537.459813308225
Within Set Sum of Squared Errors k=31: 4540.172484329018
Within Set Sum of Squared Errors k=32: 4423.35034326994
Within Set Sum of Squared Errors k=33: 4285.911333043882
Within Set Sum of Squared Errors k=34: 4274.494793461957
Within Set Sum of Squared Errors k=35: 4217.912477629837
Within Set Sum of Squared Errors k=36: 4152.399043774064
Within Set Sum of Squared Errors k=37: 4055.4311533375326
Within Set Sum of Squared Errors k=38: 3947.4512508265857
Within Set Sum of Squared Errors k=39: 3938.5658991886667
Within Set Sum of Squared Errors k=40: 3866.6083878490845
Within Set Sum of Squared Errors k=41: 3775.6842133087935
Within Set Sum of Squared Errors k=42: 3752.32716586355
Within Set Sum of Squared Errors k=43: 3762.3506009623247
Within Set Sum of Squared Errors k=44: 3690.2542288397613
Within Set Sum of Squared Errors k=45: 3578.7506830825496
Within Set Sum of Squared Errors k=46: 3638.5892031723697
Within Set Sum of Squared Errors k=47: 3566.0993185854672
Within Set Sum of Squared Errors k=48: 3483.0278326454627
Within Set Sum of Squared Errors k=49: 3439.982429989299
Within Set Sum of Squared Errors k=50: 3453.5386195195106
Within Set Sum of Squared Errors k=51: 3325.5184911375127
Within Set Sum of Squared Errors k=52: 3356.7219232135394
Within Set Sum of Squared Errors k=53: 3333.353055111793
Within Set Sum of Squared Errors k=54: 3235.9964657322994
Within Set Sum of Squared Errors k=55: 3236.377224981873
Within Set Sum of Squared Errors k=56: 3273.1750297971057
Within Set Sum of Squared Errors k=57: 3186.680069516532
Within Set Sum of Squared Errors k=58: 3148.4247791747853
Within Set Sum of Squared Errors k=59: 3073.233075573848
Within Set Sum of Squared Errors k=60: 3106.514018017732
Within Set Sum of Squared Errors k=61: 3012.5695961133247
Within Set Sum of Squared Errors k=62: 2969.3436147355774
Within Set Sum of Squared Errors k=63: 3003.010009661414
Within Set Sum of Squared Errors k=64: 2996.5538616421004
Within Set Sum of Squared Errors k=65: 2899.260052040212
Within Set Sum of Squared Errors k=66: 2904.5868069896746
Within Set Sum of Squared Errors k=67: 2883.5903407164114
Within Set Sum of Squared Errors k=68: 2882.5572249899196
Within Set Sum of Squared Errors k=69: 2793.2168639344063

Notes:

  • As we can see, the elbow curve varies between k> = 10 and k <= 12
  • Considering that 70% is a reasonable thrashold for variability reduction, we will choose k = 11 as the optimal number of clusters

The following figure summarizes the size of the actual customer in each cluster.

In [38]:
actual_customers = actual_customers.drop('cluster')
cluster_results = cluster_data(actual_customers, feature_cols, 11, seed=SEED)
actual_customers = cluster_results['predictions'].withColumnRenamed('prediction', 'cluster')

#boxplots(actual_customers.groupBy('segment').agg(F.count('*').alias('count')), ['count'])
grouped_bar_plot([actual_customers], ['actual customers'], 'cluster')

Note:

  • As we can see, two segments can be considered outliers as they have fewer customer sizes.

Decision-tree classifier

Based on clustered data, we will train a decision tree classisifier:

  • Split training and test datasets
  • handle imbalanced classes hybrid method: upsampling and downsampling (threshold = avg of class freq).
  • analyse cluster label sizes after resampling
  • train a decision tree classifier using cross validation and grid search
  • classfier offline validation using accuracy
In [46]:
label_col = 'cluster'
prediction_label = 'prediction'

(training, testing) = actual_customers.randomSplit([0.7, 0.3], seed=SEED)

training = checkpoint(training, 'training.parquet')
testing = checkpoint(testing, 'testing.parquet')

training = resample(training, label_col, seed=SEED)
training = checkpoint(training, 'training_resampled.parquet')

grouped_bar_plot([training], ['training'], label_col)

param_grid = {
    'maxDepth': [5, 8, 13],
    'maxBins': [13, 21, 34],
    'impurity': ['entropy', 'gini'],
    'seed': [SEED]
}

(model, testing_score, training_score, testing, training, best_param_map) = train_multiclass_classifier(
    DecisionTreeClassifier,
    training,
    testing,
    feature_cols,
    label_col,
    param_grid,
    metric='accuracy')

print('testing_score vs training_score:', testing_score, training_score)
Max items for each label class:
128
testing_score vs training_score: 0.8384879725085911 0.9993108201240524

Notes:

  • The classifier performs well without overfitting and underfitting

Customer scorer

In this section we will perform the following tasks:

  • predict the addressable customers' clusters using previously trained classifier
  • estimate score by applying pearson correlation between predict cluster centroids and features for each addressable customer.
In [73]:
@F.udf("double")
def udf_corr_scoring(point, cluster):
    point = pd.Series([float(i) for i in point])
    corr_series = centers.corrwith(point)
    result = corr_series[int(cluster)]

    return float(result)

# predict addressable_customers' clusters and estimate correlation core 
addressable_customers = addressable_customers.drop('features', 'prediction', 'rawPrediction', 'probability')
addressable_customers = model.transform(assemble_vector(addressable_customers, feature_cols, 'features'))
addressable_customers = addressable_customers.withColumn('score', udf_corr_scoring(F.array(feature_cols), prediction_label))

addressable_customers.show(5)
+-----+-----------+-------------------+-------------------+--------------------+-------------------+--------------------+------------------+--------------------+------------------+--------------------+-------------------+-----------+--------------------+--------------------+--------------------+--------------------+----------+
|   id|       type|             feat_0|             feat_1|              feat_2|             feat_3|              feat_4|            feat_5|              feat_6|            feat_7|              feat_8|             feat_9|addressable|               score|            features|       rawPrediction|         probability|prediction|
+-----+-----------+-------------------+-------------------+--------------------+-------------------+--------------------+------------------+--------------------+------------------+--------------------+-------------------+-----------+--------------------+--------------------+--------------------+--------------------+----------+
| 1433|SUPERMARKET|-1.1893515831276014| 0.6553077741865874|  0.5321357217411693|-0.9047329371040032|-0.37661701735138003|0.5083153079668359|0.025418846862648663|0.1456625868925321|  0.5212049683828303| -2.000341126780125|       true|0.044124067677148424|[-1.1893515831276...|[0.0,0.0,0.0,0.0,...|[0.0,0.0,0.0,0.0,...|       9.0|
| 2442| RESTAURANT|  2.382089613949935|-1.0133302366233687|  1.7104961717111093| 1.2259220449874677|   2.433062781419892|1.6819553878726934| -2.9670982691420624|1.6123003074528977|  0.5766596624272641|  4.072812737600505|       true|  0.5860626380065698|[2.38208961394993...|[0.0,0.0,0.0,0.0,...|[0.0,0.0,0.0,0.0,...|       9.0|
|12895| RESTAURANT|-1.6205384494374864| 0.7065382282683983|  0.2315698097419313|0.17190788224878797|-0.12990643696007778|-0.624523559301969|-0.33374116319943786|0.3999247420045181| -0.3706415508642893|-0.5975042300871813|       true| 0.13671498415679853|[-1.6205384494374...|[0.0,0.0,0.0,0.0,...|[0.0,0.0,0.0,0.0,...|       5.0|
|11736|SUPERMARKET| 2.3364931593372416|-1.0498802490836292|-0.03238424520598936|-0.6520216265369067| -0.5876790396694681|0.5537788160892412| -1.7500478673126412|1.7674978944507493|-0.03749685781425249|  2.102660183272527|       true|  0.7367670282998992|[2.33649315933724...|[92.0,0.0,0.0,0.0...|[1.0,0.0,0.0,0.0,...|       0.0|
|14052| RESTAURANT| 1.7100122012467789|-0.7142546716806903|   2.762110985627179| 1.2309835503006887|   3.189441400920434|2.1431331277128107|  -3.665262670848533| 2.111510930288886|  1.1024012143182214| 3.8881656309904677|       true|   0.752929882120784|[1.71001220124677...|[0.0,0.0,0.0,0.0,...|[0.0,0.0,0.0,0.0,...|       9.0|
+-----+-----------+-------------------+-------------------+--------------------+-------------------+--------------------+------------------+--------------------+------------------+--------------------+-------------------+-----------+--------------------+--------------------+--------------------+--------------------+----------+
only showing top 5 rows

Generate Deliverables

In [93]:
def save_deliberable(df, filename):
    filename = cos.url(filename, 'potentialmarketranking-donotdelete-pr-cej2kccafd4zxc')
    df.repartition(1).write.mode('overwrite').option("header", "true").csv(filename)


def load_deliberable(filename):
    filename = cos.url(filename, 'potentialmarketranking-donotdelete-pr-cej2kccafd4zxc')
    
    return spark.read\
      .format('org.apache.spark.sql.execution.datasources.csv.CSVFileFormat')\
      .option('header', 'true')\
      .load(filename)

Finally, let's generate the challange deliverables.

Deliverable 1

In [78]:
save_deliberable(addressable_customers.select('id'), 'addressable_ids.csv')

Deliverable 2

In [81]:
save_deliberable(training.select('id'), 'training_ids.csv')
save_deliberable(testing.select('id'), 'testing_ids.csv')
save_deliberable(addressable_customers.select('id', 'score').orderBy(F.desc('score')), 'addressable_ranking.csv')
print('OK')
OK

Sanity Check

Let's check if the deliverable datasets were generated and saved correctly.

In [94]:
print('training_ids.csv')
load_deliberable('training_ids.csv').show(5)

print('testing_ids.csv')
load_deliberable('testing_ids.csv').show()

print('addressable_ranking.csv')
load_deliberable('addressable_ranking.csv').show()
training_ids.csv
+----+
|  id|
+----+
| 162|
|1177|
|2665|
|3069|
|4935|
+----+
only showing top 5 rows

testing_ids.csv
+----+
|  id|
+----+
| 974|
| 995|
|1236|
|1469|
|1553|
|1638|
|1814|
|1891|
|1985|
|2049|
|2445|
|3093|
|3174|
|3735|
|3746|
|3953|
|4646|
|4766|
|4984|
|5058|
+----+
only showing top 20 rows

addressable_ranking.csv
+-----+------------------+
|   id|             score|
+-----+------------------+
| 7462|0.9992117688834464|
| 2446|0.9984523503017962|
| 4711|0.9983787872319342|
| 4294|0.9983042582540658|
|12443|0.9982600101852419|
|13375|0.9981577560105647|
|  102|0.9977551605356462|
| 2824|0.9972083640507304|
|17358|0.9971874403260467|
|18764|0.9970584680641514|
|17442|0.9969639429287497|
| 3273|0.9968587467695936|
| 4964|0.9968065026629201|
|10816|0.9967813319537918|
| 8703|0.9967386222473469|
| 9410|0.9967187700464519|
| 6801|0.9967072898304182|
|19899| 0.996675467532045|
|14346|0.9963829451751417|
| 6014|0.9962525095151363|
+-----+------------------+
only showing top 20 rows

Improvements

TODO:

 - Redistribute customers from small clusters into other clusters
 - persona characterization (improve explicability) - analyze centroids to describe and differentiate each cluster behavior
 - create files to store python modules: utils, config, clustering, etc.
 - save models (clustering, classsifier)
 - use spark pipelines
 - compute processing time for clustering and predictions
 - try to use euclidian distance to compute score

Cluster based classifier and scorer (deprecated)

Clustering is not able to classify instances of companies in itself. Instead, a simple classification model was built on top of the result of the clustering (reference). This was conducted with the following algorithm

  1. Split customers c1, c2, . . . , c644 into a training set (80%) and a test set (20%)
  2. Cluster the training-data = actual-customers using the K-means algorithm
  3. Calculate the centroids for each cluster by taking the distance to the data point in the training set
  4. Predict the closest cluster for training-set (actual customers) using pearson similiraty method
  5. Predict the closest cluster for unlabeled companies (addressable customers = testing-set) using pearson similiraty method
  6. Compare the predicted centroids for both training-set and unlabeled companies
  7. Rank addressable customers using pearson similiraty
In [40]:
centers = pd.DataFrame(cluster_results['centers']).T


def assemble_vector(dataframe, input_cols, output_col):
    '''
    Combine a given list of columns into a single vector column in a a pyspark dataframe

    https://spark.apache.org/docs/latest/ml-features.html#vectorassembler
    '''
    assembler = VectorAssembler(inputCols=input_cols, outputCol=output_col)
    return assembler.transform(dataframe.na.drop())


def find_cluster(point):    
    '''
    Find the most correlated cluster of a given point

    1. Compute the pairwise correlation between each row of a clusters' centers (pd.Dataframe) vs a dimensional point (pd.Series).

    2. Rank the correlations

    3. Return a tuple containing: the most correlated cluster of point, correlation estimation.

    Ref:
        - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.corrwith.html
        - https://stackoverflow.com/a/38468711/4159153
        - https://stackoverflow.com/a/35459076/4159153

    TODO: Use euclidian distance to compute score
        ```
        from scipy.spatial import distance
        a = (1, 2, 3)
        b = (4, 5, 6)
        distance.euclidean(a, b)
        ```
    '''
    point = pd.Series([float(i) for i in point])
    corr_series = centers.corrwith(point)
    corr_series = corr_series.sort_values(ascending=False)

    result = list(zip(corr_series.index, corr_series))[0] #  first result from ranking

    return (float(result[0]), float(result[1])) # (cluster, correlation)


@F.udf("double")
def udf_estimate_cluster(point):
    return find_cluster(point)[0]


@F.udf("double")
def udf_estimate_score(point):
    return find_cluster(point)[1]


def predict_cluster(df):
    df = assemble_vector(df, feature_cols, 'features')
    df = df.withColumn('prediction', udf_estimate_cluster(F.col('features')))
    df = df.withColumn('score', udf_estimate_score(F.col('features')))
    df = df.drop('features')

    return df


training = checkpoint(actual_customers.distinct(), 'training.parquet')
testing = checkpoint(addressable_customers.distinct(), 'testing.parquet')

training = predict_cluster(training)
testing = predict_cluster(testing)

print('Actual customers predictions based on cluster centroids correlation')
training.show(3)

print('Addressable customers predictions based on cluster centroids correlation')
testing.show(3)
Actual customers predictions based on cluster centroids correlation
+-----+-----------+-------------------+--------------------+-------------------+--------------------+-------------------+-------------------+-------------------+--------------------+-------------------+------------------+-----------+-------+----------+------------------+
|   id|       type|             feat_0|              feat_1|             feat_2|              feat_3|             feat_4|             feat_5|             feat_6|              feat_7|             feat_8|            feat_9|addressable|cluster|prediction|             score|
+-----+-----------+-------------------+--------------------+-------------------+--------------------+-------------------+-------------------+-------------------+--------------------+-------------------+------------------+-----------+-------+----------+------------------+
|16656| RESTAURANT|-0.3121369858601416| -1.0950151144230416| 1.1582380877536358|  -3.557199062329223| -1.434006384699189|-2.0291576285820536| 1.6232181755983899|   1.897762084945093| 2.3267727498705573|-1.839375478951836|      false|      6|       6.0|0.9375812235007218|
| 6395| RESTAURANT|-0.4375843640892725|  1.2423605822495392|-2.6902863050871706|  0.9593601995637638|-1.5960187773130137| 0.8963132732691619| 1.1518704700078057|  -2.700683779055528| -2.079477818528634|-2.063190889907343|      false|      1|       4.0|0.9421953641191076|
| 2899|SUPERMARKET| 2.0482607103584685|-0.24018143979725687| -2.276307089473967|-0.45952605749182984|-2.4022814024487804| 0.8507423618624097|-0.6772778673997445|0.002392871837704...|-1.5945587976247015|0.2742188101319103|      false|      2|       2.0|0.9581429016235314|
+-----+-----------+-------------------+--------------------+-------------------+--------------------+-------------------+-------------------+-------------------+--------------------+-------------------+------------------+-----------+-------+----------+------------------+
only showing top 3 rows

Addressable customers predictions based on cluster centroids correlation
+-----+-----------+-------------------+-------------------+------------------+-------------------+--------------------+------------------+--------------------+------------------+-------------------+-------------------+-----------+----------+-------------------+
|   id|       type|             feat_0|             feat_1|            feat_2|             feat_3|              feat_4|            feat_5|              feat_6|            feat_7|             feat_8|             feat_9|addressable|prediction|              score|
+-----+-----------+-------------------+-------------------+------------------+-------------------+--------------------+------------------+--------------------+------------------+-------------------+-------------------+-----------+----------+-------------------+
| 1433|SUPERMARKET|-1.1893515831276014| 0.6553077741865874|0.5321357217411693|-0.9047329371040032|-0.37661701735138003|0.5083153079668359|0.025418846862648663|0.1456625868925321| 0.5212049683828303| -2.000341126780125|       true|       5.0|  0.462070331713762|
| 2442| RESTAURANT|  2.382089613949935|-1.0133302366233687|1.7104961717111093| 1.2259220449874677|   2.433062781419892|1.6819553878726934| -2.9670982691420624|1.6123003074528977| 0.5766596624272641|  4.072812737600505|       true|       9.0| 0.5860626380065698|
|12895| RESTAURANT|-1.6205384494374864| 0.7065382282683983|0.2315698097419313|0.17190788224878797|-0.12990643696007778|-0.624523559301969|-0.33374116319943786|0.3999247420045181|-0.3706415508642893|-0.5975042300871813|       true|       9.0|0.44593600329850275|
+-----+-----------+-------------------+-------------------+------------------+-------------------+--------------------+------------------+--------------------+------------------+-------------------+-------------------+-----------+----------+-------------------+
only showing top 3 rows

Off-line evaluation

In [43]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

def evaluate_multiclass_model(predictions_data, label_col, prediction_col='prediction', metric_name='accuracy'):
    evaluator = MulticlassClassificationEvaluator(labelCol=label_col, predictionCol=prediction_col, metricName=metric_name)
    score = evaluator.evaluate(predictions_data)
    return score


eval_metrics = ['accuracy', 'weightedPrecision', 'weightedRecall']
label_col = 'cluster'
prediction_col = 'prediction'

for metric in eval_metrics:
    global_train_score = evaluate_multiclass_model(training, label_col, prediction_col)
    print('{metric} for training : {global_train_score}'.format(metric=metric,global_train_score=global_train_score))
accuracy for training : 0.7915
weightedPrecision for training : 0.7915
weightedRecall for training : 0.7915

Sanity check

In [44]:
training.groupBy('cluster').agg({x: "avg" for x in feature_cols}).show()
training.groupBy('prediction').agg({x: "avg" for x in feature_cols}).show()
testing.groupBy('prediction').agg({x: "avg" for x in feature_cols}).show()
+-------+--------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+-------------------+--------------------+--------------------+
|cluster|         avg(feat_6)|         avg(feat_0)|         avg(feat_5)|        avg(feat_9)|         avg(feat_1)|         avg(feat_8)|         avg(feat_2)|        avg(feat_7)|         avg(feat_3)|         avg(feat_4)|
+-------+--------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+-------------------+--------------------+--------------------+
|      1|   1.071574286515074| 0.21605266768896084|  0.5551523755368658|-1.8737561779709861|  0.6900107746633927| -1.6755056072952657| -2.5965669304926813|-1.8291142452322262|-0.01465062616213...|  -2.211940684002932|
|      6|  0.9530697457880998|  0.9934827839149837| -0.7607790085081501|-0.8023102186855207| -1.2079995530926546|  1.7890353976005635|  0.6641751210547338| 1.3180120221466154| -2.6768528612860982| -1.0475675756106295|
|      3|0.013464580222680263|  0.9436480585337791|   1.452973827991778|-0.7545054183700018|-0.06947847396461303| 0.37250030421283614|-0.13191004993350847|-0.4755488209095418| -0.5318380242889723|  -0.221350390819194|
|      5|  0.6760385997465255| -0.8515171915772918|    1.60890790514618| -2.257476609627945|  1.2275880870544447| -0.6561947540792081| -0.8798420052027012|-2.1436723235982895|  0.6513500926048165|-0.16926264759212964|
|      9| -1.1197114837679831| -0.7026757798203555| 0.19842049106324716| 0.8958684006926257| 0.14784233527432147|   0.613265134659067|   1.580021793253161| 0.8209173994106289|  0.5993147444736406|  1.6065229081498575|
|      4| -0.1426437330802988|-0.45267445679244056|  1.9383844318770038|-1.9741226514590224|  1.7619061057773946| -2.7676579441066496|  -2.873915628257285|-2.4778556769790914|   1.377954361628999|  -1.763847997082497|
|      8|  1.1800557131925509|  0.9662402990299142|  0.3410774180213643|-1.2547284008317159|-0.28106304471865984|0.024578569654849437| -1.1204947976876338|-0.8078053860328639| -1.0250541252564491|  -1.355050616545594|
|      7| -0.5135509564336395|  0.7845450778267281|  1.7920161562863062|-0.8320393062961082|  0.6638517279184281| -1.4433851396267539| -1.7280421208939987|-1.1267391091962176|  0.3853244469596441| -1.2766321063228416|
|     10|  0.9514677698111426|  2.1058833166656514|-0.13312366918466922|-0.5334553126959164| -0.6532419173994254| -1.1350901222948304|  -2.659359837165637|-0.4672943712963798|  -1.268261448451162|  -3.004586167983706|
|      2| -0.6372513327243693|   1.849130735241909|   1.242126886722189|-0.1837833123418398|  0.2395227446629334|  -2.377789177169931| -3.0629194092182455|-0.7271669896678767| 0.02358793401626674|  -2.803155881498008|
|      0|-0.20572185578899876|   2.661579319200235|  0.6826192729005873| 0.4408805219938419| -0.9997742821987676|-0.23688608260579183| -1.3051392395782906|0.47996977929085866| -1.3343477534795685| -1.8567446578651883|
+-------+--------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+-------------------+--------------------+--------------------+

+----------+--------------------+-------------------+--------------------+-------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+
|prediction|         avg(feat_6)|        avg(feat_0)|         avg(feat_5)|        avg(feat_9)|         avg(feat_1)|         avg(feat_8)|        avg(feat_2)|         avg(feat_7)|         avg(feat_3)|         avg(feat_4)|
+----------+--------------------+-------------------+--------------------+-------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+
|       8.0|  1.3840841677084437| 1.0131979573837016|  0.2529830180922895| -1.345470898511208|-0.35727447514258404| 0.13507499238958487|-1.1262184607025287| -0.8605085007634882| -1.1465475869221047| -1.3920581140910933|
|       0.0| -0.1628742760389539| 2.6672853855661636|  0.6222503190749948|0.44634414410111084| -1.0382827238131145|-0.18059005408066547|-1.2694741983579347|   0.514599391236229| -1.3851939932763104|  -1.852404686129519|
|       7.0| -0.6862422495906356| 0.7935212654066869|  1.9777809031639415|-0.8086390748112348|  0.7639194865671812| -1.6059625182779396|   -1.8176492645861| -1.1817197810114335|  0.5242226879581598| -1.2913717583273772|
|       1.0|   1.019044119312481| 0.2757193667402122|  0.6123533375609794| -1.799462645404406|  0.6293478976178691| -1.5251484820760115|  -2.42878269603448| -1.7469670673879187|-0.05521656348342101|  -2.072742198351179|
|       4.0|-0.14489723462481852|-0.3574293764949105|  1.9335479217997762|-1.8497710203708722|  1.5927723517807537|   -2.38689870395991|-2.4971601309754226| -2.2847304650290035|  1.2206958883145673| -1.4884939157332886|
|       3.0| -0.0481420355223652| 0.9755575798467644|   1.588815628533787|-0.7469485874290014|-0.06181585369465211|  0.4332820419445692|-0.0476143041314624| -0.4884598671514655| -0.5079321861326859| -0.1160576945136141|
|       2.0| -0.5991834483351283|  1.747451715229772|   1.250897881924546|-0.1999848305082562| 0.17135297888394305|  -1.984259759467475|-2.6265006987542665| -0.6168753707595849|-0.07215827954639607| -2.4436649107691446|
|      10.0|  0.8616953939063546| 1.9630921317431795|0.012107763119531077|-0.5456952142409301| -0.5823286138725119| -0.9990145568816777|-2.4073985937617195|-0.48477983639493355| -1.1611390271726274| -2.7055202600484844|
|       6.0|  0.9494495820146144| 0.9323921667334581|  -0.622732274307123|-0.7688850635333763| -1.1092522761293826|  1.6990360038614063| 0.6409293976504327|  1.0923198469978295|  -2.413266661287432| -0.8584586906326653|
|       5.0|   0.709825870445307|-0.8348635352415527|   1.511381832770946| -2.241927433040746|  1.1490115528875304| -0.5265526209178581|-0.7767150027386268|  -2.021069450729347|  0.5216369817642373|-0.16423447969124377|
|       9.0| -1.3235992844943076|-0.8365706035195388| 0.24756844082767945|  1.032475445568816|  0.2547196757702479|   0.490712880081246| 1.6279128035194121|  0.8436553839189556|  0.8204646640372271|  1.7452488723771116|
+----------+--------------------+-------------------+--------------------+-------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+

+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+
|prediction|         avg(feat_6)|         avg(feat_0)|         avg(feat_5)|         avg(feat_9)|         avg(feat_1)|         avg(feat_8)|         avg(feat_2)|         avg(feat_7)|         avg(feat_3)|        avg(feat_4)|
+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+
|       8.0|  1.9637665018380983|  1.1049584892089819|-0.21046099659555806| -1.4791628831568815|   -0.59666853780767|  0.3503656064224171|  -1.213485672777509| -0.9278726273337795| -1.4617866028170523|-1.5749920177614958|
|       0.0|-0.06707149041719072|   2.424298614966093| 0.33232577854845746|  1.2029576660163015| -1.2042998332174095| 0.33473435736878177|-0.43405719024945266| 0.47540667775486123| -0.7994144377998047|-0.5756002157419436|
|       7.0|  -0.645694715460231|  0.8619514479273754|  1.9182608278934272| -0.7164692025428567|    0.69410511869723| -1.5565868516363548| -1.7909136341470768| -1.1613252321687937|  0.5143001126412066|-1.2507420871483876|
|       1.0|     1.0488145504199|  0.3387940384448437|  0.6286088730599976| -1.8046030513328555|  0.6161727277396036| -1.5605967459160575|  -2.503020278665543| -1.7889011616658634|-0.06094532021667432| -2.128742794465861|
|       4.0| -0.2069000645188308|  -0.483642534074587|  1.7855003751207292| -1.7508827870500954|  1.6193170001060027| -2.4486334620585324|  -2.470678077041857| -2.1946696867414075|  1.2912829126239969|-1.4652251135024505|
|       3.0| -0.2647594432942133|  0.7889412142422555|  1.5776503682918819| -0.5886904267056541|-0.04817999601773439|  0.6303220922655197| 0.34974567623326847|-0.25481213838221844| -0.4496023742060965|0.22624804521792538|
|       2.0| -0.4946958006161795|  1.7386437847782275|  1.1674116957474943|-0.23755649870685175|  0.1662157342709808|  -2.028117501841253| -2.7169936142983304| -0.6779963101580613|-0.07708735481900951|  -2.51613165745582|
|      10.0|  1.5410659987571098|   2.286095404279387| -0.5548340511577113|-0.16604511098057467| -1.0704244413302686| -0.5059895611874382| -2.1979347569391505| -0.5714159664633033|  -1.286632946628108|-2.3628612659485357|
|       6.0|   0.729033095136791|-0.10881018237241187|  -1.144796666480684| -1.4329740277204428| -0.7608118708078342|  1.7204763664973013|  0.9354415315420096|   1.750402452919789| -2.8629368879070296|-1.2598936695641316|
|       5.0|  0.6688134057851016| -1.9044192484906157| 0.13334221176351788| -1.4970790742707092|  1.1250092391570354|-0.32219880216580116| 0.01725658955215726|  -1.333910020194394|  0.8993911215982188| 0.5796969208279608|
|       9.0| -1.5928112889752555| -0.7947613458589857|  0.5522173642959225|   1.042991808644455| 0.28182977495382744|  0.6336689766647088|  1.8667058867399735|  0.9776743904658439|  0.7877562254652667| 1.9137936752411435|
+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+